# http://proceedings.mlr.press/v101/huang19a/huang19a.pdf
# https://www.researchgate.net/publication/220875351_Generative_Models_for_Labeling_Multi-object_Configurations_in_Images
# https://www.tensorflow.org/datasets/catalog/open_images_v4
# Auto-Encoding Progressive Generative Adversarial Networks For 3D Multi Object Scenes
TODO: additional datasets to experiment with (see the links above).
%config Completer.use_jedi = False
from ipywidgets import IntProgress
import matplotlib.pyplot as plt
from tensorflow.keras import layers, losses
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import logging
import tensorflow_datasets as tfds
import pandas as pd
from tqdm import tqdm_notebook as tqdm
from sklearn.mixture import GaussianMixture
import os
# Reproducibility: seed both NumPy and TensorFlow from the same value.
# (Fix: the original declared `seed` but hard-coded the literal 1 in the
# seeding calls, so changing `seed` would silently have no effect.)
seed = 1
np.random.seed(seed)
tf.random.set_seed(seed)
# Training hyper-parameters and dataset selection.
batch_size = 32
epochs = 10
dataset_name = 'coil100'
# Load train / test / validation splits for the configured dataset.
if dataset_name == 'bdd100k':
    # Local image folders; the directory loader batches for us.
    _from_dir = tf.keras.preprocessing.image_dataset_from_directory
    train_ds = _from_dir(directory='../data/bdd100k/images/10k/train1/', batch_size=batch_size)
    test_ds = _from_dir(directory='../data/bdd100k/images/10k/test1/', batch_size=batch_size)
    validation_ds = _from_dir(directory='../data/bdd100k/images/10k/val1/', batch_size=batch_size)
elif dataset_name == 'coil100':
    # Single-split dataset: reuse the train split for test and validation.
    (train_ds,) = tfds.load(name=dataset_name, split=['train'],
                            as_supervised=False, download=True)
    test_ds = train_ds
    validation_ds = train_ds
elif dataset_name in ('flic', 'fashion_mnist', 'mnist', 'kitti'):
    # Two official splits; validate on the test split.
    train_ds, test_ds = tfds.load(name=dataset_name, split=['train', 'test'],
                                  as_supervised=False, download=True)
    validation_ds = test_ds
elif dataset_name in ('wider_face',):
    # All three official splits are available.
    train_ds, test_ds, validation_ds = tfds.load(name=dataset_name,
                                                 split=['train', 'test', 'validation'],
                                                 as_supervised=False, download=True)
else:
    raise ValueError(f'Unhandled dataset {dataset_name}')
# Collect per-example image dimensions into a DataFrame so we can pick a
# common target size below.
if dataset_name == 'bdd100k':
    # Directory loader yields (image_batch, label) tuples, so shapes include
    # a leading batch dimension.
    records = [x[0].get_shape().as_list() for x in train_ds]
    dims_df = pd.DataFrame.from_records(data=records,
                                        columns=['batch', 'height', 'width', 'depth'])
else:
    # TFDS examples are dicts; the image tensor is unbatched here.
    records = [x['image'].get_shape().as_list() for x in train_ds]
    dims_df = pd.DataFrame.from_records(data=records,
                                        columns=['height', 'width', 'depth'])
dims_df.describe()
| height | width | depth | |
|---|---|---|---|
| count | 7200.0 | 7200.0 | 7200.0 |
| mean | 128.0 | 128.0 | 3.0 |
| std | 0.0 | 0.0 | 0.0 |
| min | 128.0 | 128.0 | 3.0 |
| 25% | 128.0 | 128.0 | 3.0 |
| 50% | 128.0 | 128.0 | 3.0 |
| 75% | 128.0 | 128.0 | 3.0 |
| max | 128.0 | 128.0 | 3.0 |
# Choose the common (square) target size: the smallest observed height/width,
# rounded down to a multiple of m (m = 1 keeps the exact minimum).
m = 1
height = int(min(dims_df['height']) / m) * m
width = int(min(dims_df['width']) / m) * m
depth = min(dims_df['depth'])
# Force square images by taking the smaller of the two sides for both.
side = min(height, width)
height = side
width = side
height, width, depth
(128, 128, 3)
# Normalize pixel values to [0, 1] (and, for TFDS datasets, resize and batch).
if dataset_name == 'bdd100k':
    # Already batched by the directory loader; drop the label and rescale.
    def _rescale(img, _label):
        return img / 255.
    train_ds = train_ds.map(lambda x0, x1: _rescale(x0, x1))
    test_ds = test_ds.map(lambda x0, x1: _rescale(x0, x1))
    validation_ds = validation_ds.map(lambda x0, x1: _rescale(x0, x1))
else:
    def _prep(example):
        # Cast to float, rescale to [0, 1], then resize to the common shape.
        img = tf.cast(example['image'], dtype=tf.float32) / 255.
        return tf.image.resize(images=img, size=[height, width])
    # drop_remainder keeps every batch the same shape (needed later when the
    # validation set is concatenated into one tensor).
    train_ds = train_ds.map(_prep).batch(batch_size, drop_remainder=True)
    test_ds = test_ds.map(_prep).batch(batch_size, drop_remainder=True)
    validation_ds = validation_ds.map(_prep).batch(batch_size, drop_remainder=True)
# Autoencoder targets equal inputs: zip each dataset with itself to get
# (x, y=x) pairs for model.fit.
train_ds_double_zipped = tf.data.Dataset.zip(datasets=(train_ds, train_ds))
test_ds_double_zipped = tf.data.Dataset.zip(datasets=(test_ds, test_ds))
validation_ds_double_zipped = tf.data.Dataset.zip(datasets=(validation_ds, validation_ds))
latent_dim = 1024

class CAE(tf.keras.Model):
    """Convolutional autoencoder (deterministic — NOT variational).

    Fix: the original docstring claimed "variational autoencoder", but there
    is no sampling or KL term here; the encoder maps an input image straight
    to a `latent_dim` code and the decoder reconstructs the image from it.
    Trained with pixel-wise MSE (see the compile call below this class).
    """

    def __init__(self, latent_dim):
        super().__init__()
        self.latent_dim = latent_dim
        self.logger = logging.getLogger('CAE')
        # Encoder: two strided 'valid' convolutions, then a linear projection
        # of the flattened features down to the latent code.
        self.encoder = tf.keras.Sequential(name='encoder', layers=[
            tf.keras.layers.InputLayer(input_shape=(height, width, depth)),
            tf.keras.layers.Conv2D(
                filters=32, kernel_size=3, strides=(2, 2), activation='relu'),
            tf.keras.layers.Conv2D(
                filters=64, kernel_size=3, strides=(2, 2), activation='relu'),
            tf.keras.layers.Flatten(),
            # No activation: the code is an unconstrained real vector.
            tf.keras.layers.Dense(latent_dim),
        ])
        # Decoder: project the code to a (h/4, w/4, 32) grid, then upsample
        # back to (h, w) with two stride-2 transposed convolutions.
        self.decoder = tf.keras.Sequential(name='decoder', layers=[
            tf.keras.layers.InputLayer(input_shape=(latent_dim,)),
            tf.keras.layers.Dense(units=int(height / 4) * int(width / 4) * 32,
                                  activation=tf.nn.relu),
            tf.keras.layers.Reshape(target_shape=(int(height / 4), int(width / 4), 32)),
            tf.keras.layers.Conv2DTranspose(
                filters=64, kernel_size=3, strides=2, padding='same',
                activation='relu'),
            tf.keras.layers.Conv2DTranspose(
                filters=32, kernel_size=3, strides=2, padding='same',
                activation='relu'),
            # No activation: outputs are unbounded (imshow clips them later).
            tf.keras.layers.Conv2DTranspose(
                filters=depth, kernel_size=3, strides=1, padding='same'),
        ])
        self.encoder.summary()
        self.decoder.summary()

    def call(self, x):
        """Run a full encode → decode pass and return the reconstruction."""
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded
# Build the autoencoder and train it to reconstruct its input with
# pixel-wise mean squared error.
cae = CAE(latent_dim)
cae.compile(optimizer='adam', loss=losses.MeanSquaredError())
Model: "encoder" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= conv2d (Conv2D) (None, 63, 63, 32) 896 _________________________________________________________________ conv2d_1 (Conv2D) (None, 31, 31, 64) 18496 _________________________________________________________________ flatten (Flatten) (None, 61504) 0 _________________________________________________________________ dense (Dense) (None, 1024) 62981120 ================================================================= Total params: 63,000,512 Trainable params: 63,000,512 Non-trainable params: 0 _________________________________________________________________ Model: "decoder" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= dense_1 (Dense) (None, 32768) 33587200 _________________________________________________________________ reshape (Reshape) (None, 32, 32, 32) 0 _________________________________________________________________ conv2d_transpose (Conv2DTran (None, 64, 64, 64) 18496 _________________________________________________________________ conv2d_transpose_1 (Conv2DTr (None, 128, 128, 32) 18464 _________________________________________________________________ conv2d_transpose_2 (Conv2DTr (None, 128, 128, 3) 867 ================================================================= Total params: 33,625,027 Trainable params: 33,625,027 Non-trainable params: 0 _________________________________________________________________
# One saved model per (dataset, latent size, input shape) combination.
model_file_path = (
    f'./models/cae_dataset_{dataset_name}_z_dim_{latent_dim}'
    f'_data_dim_{height}x{width}x{depth}'
)
print(f'model path = {model_file_path}')
model path = ./models/cae_dataset_coil100_z_dim_1024_data_dim_128x128x3
# Load a previously-saved model if present; otherwise train from scratch,
# checkpointing the best-so-far weights each epoch.
if os.path.exists(model_file_path):
    print('loading saved model')
    cae = tf.keras.models.load_model(filepath=model_file_path)
else:
    print('building model')
    # Use checkpoints to save model fitting progress:
    # https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/ModelCheckpoint
    checkpoint_filepath = './checkpoint'
    # BUG FIX: val_loss must be MINIMIZED. The original used mode='max',
    # which with save_best_only=True keeps the weights from the *worst*
    # validation epoch.
    model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=checkpoint_filepath,
        save_weights_only=True,
        monitor='val_loss',
        mode='min',
        save_best_only=True)
    # Model weights are saved at the end of every epoch, if it's the best
    # seen so far.
    cae.fit(x=train_ds_double_zipped, validation_data=test_ds_double_zipped,
            epochs=epochs, callbacks=[model_checkpoint_callback],
            use_multiprocessing=True)
    # Restore the best (lowest val_loss) weights before persisting the model.
    cae.load_weights(checkpoint_filepath)
    print('saving model')
    cae.save(filepath=model_file_path)
building model Epoch 1/10 225/225 [==============================] - 158s 697ms/step - loss: 0.0352 - val_loss: 0.0099 Epoch 2/10 225/225 [==============================] - 148s 660ms/step - loss: 0.0084 - val_loss: 0.0058 Epoch 3/10 225/225 [==============================] - 148s 660ms/step - loss: 0.0055 - val_loss: 0.0041 Epoch 4/10 225/225 [==============================] - 149s 664ms/step - loss: 0.0044 - val_loss: 0.0033 Epoch 5/10 225/225 [==============================] - 150s 665ms/step - loss: 0.0035 - val_loss: 0.0029 Epoch 6/10 225/225 [==============================] - 1846s 8s/step - loss: 0.0030 - val_loss: 0.0027 Epoch 7/10 225/225 [==============================] - 191s 852ms/step - loss: 0.0027 - val_loss: 0.0026 Epoch 8/10 225/225 [==============================] - 157s 700ms/step - loss: 0.0025 - val_loss: 0.0022 Epoch 9/10 225/225 [==============================] - 167s 742ms/step - loss: 0.0023 - val_loss: 0.0020 Epoch 10/10 225/225 [==============================] - 164s 731ms/step - loss: 0.0021 - val_loss: 0.0018 saving model INFO:tensorflow:Assets written to: ./models/cae_dataset_coil100_z_dim_1024_data_dim_128x128x3/assets
INFO:tensorflow:Assets written to: ./models/cae_dataset_coil100_z_dim_1024_data_dim_128x128x3/assets
# Create the validation dataset tensor.
# Fix: the original primed a dataset.reduce with a zeros "dummy" batch and
# then sliced it off again — fragile (it assumed the first batch's shape)
# and needlessly convoluted. A single concat over the batches is equivalent
# because drop_remainder=True makes every batch the same shape.
validation_ds_tensor = tf.concat([batch for batch in validation_ds], axis=0)
# Calculate loss; values are comparable across datasets because all inputs
# are scaled to [0, 1].
y_predicted = cae.predict(validation_ds)
cae_loss = cae.loss(y_pred=y_predicted, y_true=validation_ds_tensor).numpy()
print(f'CAE loss for dataset {dataset_name} = {np.round(cae_loss,4)}')
CAE loss for dataset coil100 = 0.00989999994635582
# Plot each original image next to its reconstruction for two batches.
for batch in validation_ds.take(2):
    z = cae.encoder(batch).numpy()
    decoded_imgs = cae.decoder(z).numpy()
    for i in range(batch.shape[0]):
        fig, (ax1, ax2) = plt.subplots(1, 2)
        ax1.imshow(batch[i])
        # The decoder's last layer has no activation, so clip explicitly to
        # [0, 1]; imshow would clip anyway but emits a warning per image.
        ax2.imshow(np.clip(decoded_imgs[i], 0., 1.))
        plt.show()
        # FIX: close each figure after rendering — the original leaked them
        # all, tripping matplotlib's "more than 20 figures open" warning and
        # holding every figure in memory.
        plt.close(fig)
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/ipykernel_launcher.py:8: RuntimeWarning: More than 20 figures have been opened. Figures created through the pyplot interface (`matplotlib.pyplot.figure`) are retained until explicitly closed and may consume too much memory. (To control this warning, see the rcParam `figure.max_open_warning`). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). 
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). 
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
# Encode the test set into a single tensor of latent codes.
inf_or_unknown_cardinality = ((test_ds.cardinality() == tf.data.INFINITE_CARDINALITY)
                              or (test_ds.cardinality() == tf.data.UNKNOWN_CARDINALITY)).numpy()
# Cap the number of batches when the dataset size cannot be determined.
batches = test_ds.cardinality().numpy() if not inf_or_unknown_cardinality else 500
z_batches = []
with tqdm(total=batches) as pbar:
    for batch in test_ds.take(batches):
        # PERF FIX: collect per-batch codes and concatenate ONCE afterwards.
        # The original tf.concat'ed inside the loop, copying the growing
        # tensor every iteration (quadratic in the number of batches).
        z_batches.append(tf.convert_to_tensor(cae.encoder(batch).numpy()))
        pbar.update(1)
z_tensor = tf.concat(z_batches, axis=0) if z_batches else None
z_tensor.shape
/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/ipykernel_launcher.py:8: TqdmDeprecationWarning: This function will be removed in tqdm==5.0.0 Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
TensorShape([7200, 1024])
# 80/20 split of the latent codes for fitting / scoring the density models.
z_np = z_tensor.numpy()
n_z = z_np.shape[0]
n_z_train = int(0.8 * n_z)
z_train, z_test = z_np[:n_z_train], z_np[n_z_train:]
# Fit Gaussian mixtures of varying size/covariance structure on the training
# codes and score held-out log-likelihood on the test codes.
random_state = 1
reg_covar = 0.1
logps = []
k_values = [1, 10, 20, 50, 70, 80, 100, 200]
# BUG FIX: the original declared cov_types = ['diag','cov'] — 'cov' is not a
# valid sklearn covariance_type — and then ignored the variable, hard-coding
# ['diag','full'] in the loop. Declare the correct list and actually use it.
cov_types = ['diag', 'full']
for k in k_values:
    for cov_type in cov_types:
        try:
            gm_fit = GaussianMixture(n_components=k, covariance_type=cov_type,
                                     random_state=random_state,
                                     reg_covar=reg_covar).fit(z_train)
            # score() returns the average per-sample log-likelihood.
            logp_gm = gm_fit.score(X=z_test)
            print(f'For Gaussian Mixture with k = {k} and cov type {cov_type}, logp = {logp_gm} ')
            logps.append({'k': k, 'cov_type': cov_type, 'logp': logp_gm})
            print('############## ')
        except Exception as e:
            # EM can fail (e.g. degenerate covariance); report and continue
            # with the remaining configurations.
            print(f'Caught exception {e} ')
For Gaussin Mixture with k = 1 and cov type diag, logp = -397.24901244255807 ############## For Gaussin Mixture with k = 1 and cov type full, logp = 155.36384970335794 ############## For Gaussin Mixture with k = 10 and cov type diag, logp = -210.10651545434658 ############## For Gaussin Mixture with k = 10 and cov type full, logp = 182.3010500637394 ############## For Gaussin Mixture with k = 20 and cov type diag, logp = -146.7904417674913 ############## For Gaussin Mixture with k = 20 and cov type full, logp = 189.89825966647012 ############## For Gaussin Mixture with k = 50 and cov type diag, logp = -37.50708497014304 ############## For Gaussin Mixture with k = 50 and cov type full, logp = 199.7700446300683 ############## For Gaussin Mixture with k = 70 and cov type diag, logp = -1.7764759001879964 ############## For Gaussin Mixture with k = 70 and cov type full, logp = 203.65747828192292 ############## For Gaussin Mixture with k = 80 and cov type diag, logp = 15.425406759482902 ############## For Gaussin Mixture with k = 80 and cov type full, logp = 204.98214107417274 ############## For Gaussin Mixture with k = 100 and cov type diag, logp = 36.45892575502282 ############## For Gaussin Mixture with k = 100 and cov type full, logp = 207.06634715741455 ############## For Gaussin Mixture with k = 200 and cov type diag, logp = 93.30223911827815 ############## For Gaussin Mixture with k = 200 and cov type full, logp = 212.15875086231188 ##############
# Tabulate the sweep results, best held-out log-likelihood first.
logps_df = pd.DataFrame(logps)
logps_df.sort_values('logp', ascending=False).reset_index()
| index | k | cov_type | logp | |
|---|---|---|---|---|
| 0 | 15 | 200 | full | 212.158751 |
| 1 | 13 | 100 | full | 207.066347 |
| 2 | 11 | 80 | full | 204.982141 |
| 3 | 9 | 70 | full | 203.657478 |
| 4 | 7 | 50 | full | 199.770045 |
| 5 | 5 | 20 | full | 189.898260 |
| 6 | 3 | 10 | full | 182.301050 |
| 7 | 1 | 1 | full | 155.363850 |
| 8 | 14 | 200 | diag | 93.302239 |
| 9 | 12 | 100 | diag | 36.458926 |
| 10 | 10 | 80 | diag | 15.425407 |
| 11 | 8 | 70 | diag | -1.776476 |
| 12 | 6 | 50 | diag | -37.507085 |
| 13 | 4 | 20 | diag | -146.790442 |
| 14 | 2 | 10 | diag | -210.106515 |
| 15 | 0 | 1 | diag | -397.249012 |